package hex.genmodel.algos.glm;

import hex.genmodel.ModelMojoReader;
import hex.genmodel.MojoReaderBackend;
import org.junit.Test;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

/**
 * Regression test for {@link GlmMultinomialMojoModel}: loads a multinomial GLM MOJO from classpath
 * resources under {@code multinomial/} (relative to this class) and checks that {@code score0}
 * reproduces a fixed table of reference predictions.
 */
public class GlmMultinomialMojoModelTest {

  /** Absolute tolerance for comparing MOJO output with the reference predictions. */
  private static final double DELTA = 0.000001;

  /**
   * Scores every row of {@code data} with the MOJO and compares it to the same row of
   * {@code expPreds}.
   *
   * <p>In the fixture below each expected row has 8 entries. The first entry equals the index of
   * the largest of the remaining seven, and those seven sum to ~1. This suggests
   * [predicted class, per-class probabilities]. NOTE(review): confirm against the MOJO's
   * documented output layout.
   */
  @Test
  public void testScore0() throws Exception {
    double[][] data = new double[][]{
      new double[]{3161, 23, 14, 228, 55, 912, 212, 210, 133, 2069, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
      new double[]{3346, 325, 11, 30, 5, 2620, 191, 227, 176, 649, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0},
      new double[]{3351, 354, 11, 60, 10, 2592, 202, 221, 157, 633, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0},
      new double[]{3350, 354, 11, 192, 53, 2348, 201, 220, 157, 543, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
      new double[]{3347, 109, 33, 60, 44, 1831, 254, 182, 27, 764, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
      new double[]{3325, 101, 35, 30, 22, 1806, 252, 170, 15, 785, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
      new double[]{3254, 49, 13, 67, 0, 1687, 225, 211, 118, 900, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0},
      new double[]{3204, 76, 6, 384, 5, 153, 228, 229, 136, 2089, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
      new double[]{2862, 320, 17, 85, 21, 1498, 174, 221, 186, 1273, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
      new double[]{2913, 53, 11, 589, 170, 1252, 227, 215, 121, 1515, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
      new double[]{2800, 80, 22, 90, 34, 1664, 243, 195, 71, 1536, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
      new double[]{3136, 267, 20, 255, 42, 190, 166, 244, 215, 2399, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
      new double[]{3222, 57, 14, 600, 94, 1283, 229, 210, 111, 1951, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
      new double[]{3141, 27, 27, 573, 223, 2200, 197, 170, 96, 2343, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
      new double[]{3352, 126, 31, 85, 50, 1915, 253, 202, 51, 768, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
      new double[]{3275, 71, 21, 60, -27, 1771, 238, 195, 79, 892, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}
    };

    double[][] expPreds = new double[][]{
      new double[]{0, 0.9027640125745652, 0.023381206510067937, 0, 0, 0, 0, 0.07385478091536198},
      new double[]{6, 0.02281238541232931, 0.0024025099965886104, 0, 0, 2.08535821810314E-4, 0, 0.9745762989173993},
      new double[]{6, 0.01797453697081243, 0.001862814626938943, 0, 0, 2.2738774133598447E-4, 0, 0.9799349797284347},
      new double[]{6, 0.04797335624870764, 5.052887475942211E-4, 0, 0, 1.1688787845222927E-4, 3.4960876111643645E-6, 0.9514005067507195},
      new double[]{6, 0.02197731283604224, 7.392104492356125E-4, 3.637252874230595E-6, 0, 0.002167956161012932, 5.620749361912734E-6, 0.9751062624949794},
      new double[]{6, 0.030221442746143787, 9.644875499204778E-4, 7.133019759124742E-6, 0, 0.0029706086395519223, 1.3416083955152668E-5, 0.9658229118628036},
      new double[]{6, 0.09262968950824312, 0.0019076463444671957, 3.235051482303505E-6, 0, 4.896307482433415E-4, 2.713048806131892E-5, 0.9049426677465835},
      new double[]{0, 0.4783834601180203, 0.34818119310460616, 0, 0, 0.001041000891336853, 0, 0.1723942891031241},
      new double[]{4, 0.285203289650867, 0.20318289949652776, 0.002320693909117664, 0, 0.4940913957138632, 0.014694834607381572, 5.067761757032798E-4},
      new double[]{1, 0.08880775619238421, 0.4571740745045857, 0.009873796003845896, 0, 0.41035922050293483, 0.03367219059928413, 1.1261524428370041E-4},
      new double[]{4, 0.050941104166928375, 0.30443643765737405, 0.001170209461701016, 1.499445643558575E-5, 0.5410803302179078, 0.10230343085847532, 5.3493181177871673E-5},
      new double[]{0, 0.7832690773053058, 0.2122329904457378, 0, 0, 0, 0, 0.004497932248955307},
      new double[]{0, 0.7309194242318252, 0.1264788627308783, 0, 0, 0.010901249273914232, 9.586620956748903E-6, 0.13169075760749563},
      new double[]{0, 0.857748128486184, 0.10652059215292309, 2.7529965252189086E-6, 0, 0.010802565671893282, 2.2013211538237387E-4, 0.0247058281765821},
      new double[]{6, 0.02032438335460342, 0.0010800516802881974, 3.3658562853189273E-6, 0, 0.0026220144452695753, 4.136521632021665E-6, 0.9759660480669061},
      new double[]{6, 0.049609361410939606, 0.0010574837639309924, 1.8824959169765657E-6, 0, 2.943415107690722E-4, 1.214332755660368E-5, 0.949024787436879}
    };

    // Guard the fixture itself: a row-count mismatch would otherwise leave expected rows silently
    // unchecked, or fail with an unhelpful ArrayIndexOutOfBoundsException.
    assertEquals("data and expPreds must have the same number of rows", data.length, expPreds.length);

    GlmMultinomialMojoModel mojo = (GlmMultinomialMojoModel) ModelMojoReader.readFrom(new ClasspathReaderBackend());

    for (int i = 0; i < data.length; i++) {
      // Size the output buffer from the expected row so the buffer and the fixture cannot drift apart.
      double[] mojoPreds = mojo.score0(data[i], new double[expPreds[i].length]);
      assertArrayEquals("prediction mismatch for row " + i, expPreds[i], mojoPreds, DELTA);
    }
  }

  /**
   * {@link MojoReaderBackend} that serves MOJO text files from classpath resources in the
   * {@code multinomial/} directory relative to this test class. Binary reads and existence checks
   * are not supported and throw {@link UnsupportedOperationException}, so an unexpected access
   * path fails the test loudly.
   */
  private static class ClasspathReaderBackend implements MojoReaderBackend {
    /**
     * Opens a text resource from the {@code multinomial/} directory.
     *
     * @param filename resource name relative to {@code multinomial/}
     * @return a reader over the resource, or {@code null} if no such resource exists
     * @throws IOException declared by the interface; not expected in practice
     */
    @Override
    public BufferedReader getTextFile(String filename) throws IOException {
      // Resolve relative to this class rather than a sibling test class, so the lookup does not
      // depend on where another test lives.
      InputStream is = GlmMultinomialMojoModelTest.class.getResourceAsStream("multinomial/" + filename);
      if (is == null) {
        return null;
      }
      // Decode explicitly as UTF-8 instead of relying on the platform default charset.
      return new BufferedReader(new InputStreamReader(is, "UTF-8"));
    }

    /** Not supported by this backend; always throws {@link UnsupportedOperationException}. */
    @Override
    public byte[] getBinaryFile(String filename) throws IOException {
      throw new UnsupportedOperationException("Unexpected call to getBinaryFile()");
    }

    /** Not supported by this backend; always throws {@link UnsupportedOperationException}. */
    @Override
    public boolean exists(String name) {
      throw new UnsupportedOperationException("Unexpected call to exists()");
    }
  }

}
